import torch
from torch import nn
import torch.optim as optim
import higher

from mlp import *
from util import *
from flow import *
from nonneg_sgd import *
from evaluation import *
from tikhonov import *

def get_fold_flow_data(G, train, valid):
    '''
        Gets tensor representation of validation flows
        and matrix to map from test flows to validation flows

        G: graph object exposing edges() and number_of_edges()
        train: container of training edges (only membership is tested)
        valid: mapping edge -> flow value for the validation edges

        Returns (flows, mapp):
            flows: float tensor of shape (len(valid), 1) holding the values
                   of valid in its iteration order
            mapp: sparse selector with one row per validation edge and one
                  column per edge of G that is not in train (columns follow
                  the order of G.edges()); entry (i, j) is 1 when the i-th
                  validation edge is the j-th out-of-train edge. It is
                  converted by sparse_tensor_from_coo_matrix before return.
    '''
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    flow_values = np.zeros((len(valid), 1))
    for i, e in enumerate(valid):
        flow_values[i, 0] = valid[e]

    # torch.tensor with explicit dtype/device replaces the legacy
    # torch.FloatTensor / torch.cuda.FloatTensor constructors.
    flows = torch.tensor(flow_values, dtype=torch.float, device=device)

    # Row of every validation edge, so each graph edge is resolved with a
    # single dictionary lookup instead of a scan over all validation edges.
    valid_row = {e: i for i, e in enumerate(valid)}

    row = []
    col = []
    data = []

    j = 0  # column index: position among the edges that are not in train
    for ej in G.edges():
        if ej in train:
            continue

        i = valid_row.get(ej)
        if i is not None:
            row.append(i)
            col.append(j)
            data.append(1.)

        j = j + 1

    mapp = coo_matrix((data, (row, col)), shape=(len(valid), G.number_of_edges()-len(train)))
    mapp = sparse_tensor_from_coo_matrix(mapp)

    return flows, mapp

def get_out_of_train_features_and_prior(G, features, priors, train, index):
    '''
        Feature representation for test edges in MLPLearnFlow

        G: graph object exposing edges() and number_of_edges()
        features: mapping edge -> 1-D feature vector (its .shape[0] gives
                  the number of features)
        priors: mapping edge -> prior flow value
        train: container of training edges (only membership is tested)
        index: mapping edge -> row position for every edge not in train

        Returns (feat, prior_flows): float tensors of shape
        (n_out_of_train, n_features) and (n_out_of_train, 1), placed on the
        GPU when CUDA is available.

        Raises IndexError if features is empty.
    '''
    # Only one feature vector is needed to read the dimensionality, so
    # avoid materialising the full list of keys.
    try:
        sample = next(iter(features.values()))
    except StopIteration:
        raise IndexError("features is empty; cannot infer the number of features") from None
    n_features = sample.shape[0]

    n_rows = G.number_of_edges() - len(train)
    feat = np.zeros((n_rows, n_features))
    prior_flows = np.zeros((n_rows, 1))

    for e in G.edges():
        if e in train:
            continue

        i = index[e]
        feat[i] = features[e]
        prior_flows[i, 0] = priors[e]

    # torch.tensor with explicit dtype/device replaces the legacy
    # torch.FloatTensor / torch.cuda.FloatTensor constructors.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    feat = torch.tensor(feat, dtype=torch.float, device=device)
    prior_flows = torch.tensor(prior_flows, dtype=torch.float, device=device)

    return feat, prior_flows

class MLPLearnFlow(nn.Module):
    '''
        Learns an MLP whose output is the optimal regularizer
        for a flow estimation problem using bilevel optimization (higher library)

        The inner problem (forward) unrolls differentiable SGD steps on a
        Tikhonov loss; the outer problem (train) updates the parameters of
        net with Adam using the validation loss of each fold.

        NOTE: the fitting entry point is called train, which is also the
        name of nn.Module.train(mode). Calling train with a single bool
        (as nn.Module.eval() does) is forwarded to nn.Module.train.
    '''
    def __init__(self, G, features, priors, lamb, net, n_folds, inner_n_iter_train, inner_n_iter_pred, outer_n_iter, inner_lr, outer_lr, nonneg=False, early_stop=10):
        '''
            G: graph whose edge flows are estimated
            features: mapping edge -> feature vector
            priors: mapping edge -> prior flow
            lamb: scalar multiplying the output of net to get the regularizer
            net: model whose parameters are optimised in the outer problem
            n_folds: number of folds passed to generate_folds
            inner_n_iter_train: maximum unrolled inner steps during training
            inner_n_iter_pred: iteration count given to the final Tikhonov solver
            outer_n_iter: maximum number of outer epochs
            inner_lr: learning rate of the inner optimizer
            outer_lr: learning rate of the outer Adam optimizer
            nonneg: use NONNEGSGD instead of SGD for the inner problem
            early_stop: window size of the early stopping rule
        '''
        super(MLPLearnFlow, self).__init__()

        self.early_stop = early_stop
        self.inner_n_iter_train = inner_n_iter_train
        self.inner_n_iter_pred = inner_n_iter_pred
        self.outer_n_iter = outer_n_iter
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.n_folds = n_folds

        self.G = G
        self.features = features
        self.priors = priors
        self.nonneg = nonneg
        self.net = net
        self.lamb = lamb

        self.use_cuda = torch.cuda.is_available()

        if self.use_cuda:
            self.cuda()

    def forward(self, A, b, x_init, reg_vec, x_prior, mapp, int_test_flows, verbose=False):
        '''
            Inner problem

            Starting from a detached copy of x_init, runs up to
            inner_n_iter_train differentiable optimizer steps on the
            Tikhonov loss and tracks the MSE between mapp @ x and
            int_test_flows after every step.

            Returns (x, valid_loss): the last iterate and its validation
            loss, both still attached to the unrolled graph.
        '''
        x = x_init.clone().detach().requires_grad_(True)

        base_opt_cls = NONNEGSGD if self.nonneg else torch.optim.SGD
        diff_kwargs = {'device': 'cuda:0'} if self.use_cuda else {}
        inner_opt = higher.get_diff_optim(base_opt_cls([x], lr=self.inner_lr), [x], **diff_kwargs)

        tk = Tikhonov(A, b, 0, 0., self.nonneg)
        loss_func = nn.MSELoss()
        losses = []
        valid_loss = None

        for epoch in range(self.inner_n_iter_train):
            tik_loss = tk(x, reg_vec, x_prior)  #tikhonov loss
            x, = inner_opt.step(tik_loss, params=[x])
            valid_loss = loss_func(torch.sparse.mm(mapp, x), int_test_flows) #validation loss
            losses.append(valid_loss.item())

            if verbose and epoch % 100 == 0:
                print("epoch: ", epoch," inner loss = ", valid_loss.item())

            # Stop once the newest loss is worse than the mean of the
            # early_stop losses that precede it.
            if epoch > self.early_stop and losses[-1] > np.mean(losses[-(self.early_stop+1):-1]):
                if verbose:
                    print("Early stopping...")
                break

        if valid_loss is None:
            # No inner iteration was run (inner_n_iter_train == 0):
            # report the validation loss of the initial point.
            valid_loss = loss_func(torch.sparse.mm(mapp, x), int_test_flows)

        return x, valid_loss

    def train(self, train_flows=True, valid_flows=None, verbose=False):
        '''
            Outer problem

            train_flows, valid_flows: mappings edge -> flow. They are merged
            and split into n_folds internal folds; net is updated once per
            fold and per epoch. Afterwards a Tikhonov solver is fitted on
            all known flows and stored in self.tk (edge positions in
            self.index).

            When called with a single bool and no valid_flows (which is how
            nn.Module.eval() and parent modules call train), the call is
            forwarded to nn.Module.train and the module is returned.
        '''
        if valid_flows is None:
            if isinstance(train_flows, bool):
                return super(MLPLearnFlow, self).train(train_flows)
            raise TypeError("train() missing required argument: 'valid_flows'")

        int_folds = generate_folds({**train_flows, **valid_flows}, self.n_folds) #No extra validation

        self.optimizer = optim.Adam(self.net.parameters(), lr=self.outer_lr)
        train_losses = []

        device = 'cuda:0' if self.use_cuda else 'cpu'

        # Warm-start point of the inner problem, one per fold.
        X = [
            torch.zeros((self.G.number_of_edges()-len(int_train),1), dtype=torch.float, device=device)
            for (int_train, int_test) in int_folds
        ]

        for epoch in range(self.outer_n_iter):
            train_loss = torch.zeros(1, device=device)

            for f, (int_train, int_test) in enumerate(int_folds):
                if verbose:
                    print("fold ", f)

                # Gradients are cleared before every step; otherwise each
                # fold would step on gradients accumulated from the folds
                # that came before it in the same epoch.
                self.optimizer.zero_grad()

                A, b, index = lsq_matrix_flow(self.G, int_train)
                int_test_flows, mapp = get_fold_flow_data(self.G, int_train, int_test)
                feat, prior_flows = get_out_of_train_features_and_prior(self.G, self.features, self.priors, int_train, index)

                reg_vec = self.lamb * self.net.forward(feat)
                x, loss = self(A, b, X[f], reg_vec, prior_flows, mapp, int_test_flows, verbose)  #forward
                X[f] = x.clone().detach()
                loss.backward()
                self.optimizer.step()

                # Detached so the epoch total does not keep every fold's
                # unrolled graph alive.
                train_loss = train_loss + loss.detach()

            train_losses.append(train_loss.item())

            print("epoch: ", epoch, " outer train loss = ", train_loss.item())

            if epoch > self.early_stop and train_losses[-1] > np.mean(train_losses[-(self.early_stop+1):-1]):
                if verbose:
                    print("Early stopping...")
                break

        # Final solver fitted on all known flows with the learned regularizer.
        known_flows = {**train_flows, **valid_flows}
        A, b, self.index = lsq_matrix_flow(self.G, known_flows)
        feat, prior_flows = get_out_of_train_features_and_prior(self.G, self.features, self.priors, known_flows, self.index)
        reg_vec = self.lamb * self.net.forward(feat).detach()
        x_init = initialize_flows(A.shape[1])
        self.tk = Tikhonov(A, b, self.inner_n_iter_pred, self.inner_lr, self.nonneg)
        self.tk.train(reg_vec, x_init, prior_flows, verbose=False)

